{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add custom modules in src directory to Jupyter path\n",
    "import sys\n",
    "sys.path.append(\"src\")\n",
    "\n",
    "# Import standard modules\n",
    "import pandas as pd\n",
    "import torch.optim as optim\n",
    "from collections import defaultdict\n",
    "\n",
    "# Import custom modules\n",
    "from setup import read_config\n",
    "from setup import setup_datasets\n",
    "from setup import setup_embedding_matrix\n",
    "from setup import setup_model\n",
    "from setup import setup_word_vectors\n",
    "from process import texts_to_features\n",
    "from utils import load_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Required inputs (change these as necessary)\n",
    "CONFIG_PATH = \"/home/cody/abcnn/configs/quora/abcnn1.yaml\"\n",
    "CHECKPOINT_PATH = \"/home/cody/abcnn/checkpoints/quora/word2vec/google_news/abcnn1/best_checkpoint\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read in the config file\n",
    "config = read_config(CONFIG_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "quora_train: 100%|██████████| 283003/283003 [01:44<00:00, 2702.51it/s]\n",
      "quora_val: 100%|██████████| 80049/80049 [00:30<00:00, 2649.83it/s]\n",
      "quora_test: 100%|██████████| 41238/41238 [00:15<00:00, 2715.52it/s]\n"
     ]
    }
   ],
   "source": [
    "# Setup the datasets\n",
    "data_paths = config[\"data_paths\"]\n",
    "embeddings_size = config[\"embeddings\"][\"size\"]\n",
    "max_length = config[\"model\"][\"max_length\"]\n",
    "datasets = {name: pd.read_csv(data_path) for name, data_path in data_paths.items()}\n",
    "datasets, texts, word2index = setup_datasets(datasets, embeddings_size, max_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Word2Vec word vectors from: /home/cody/abcnn/embeddings/word2vec/google_news/GoogleNews-vectors-negative300.bin.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "embedding matrix: 100%|██████████| 85856/85856 [00:00<00:00, 354274.10it/s]\n"
     ]
    }
   ],
   "source": [
    "# Setup the embedding matrix\n",
    "word_vectors = setup_word_vectors(config)\n",
    "embeddings = setup_embedding_matrix(word_vectors, word2index, embeddings_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating the ABCNN model...\n"
     ]
    }
   ],
   "source": [
    "# Setup the model and history dict\n",
    "model = setup_model(embeddings, config)\n",
    "state = load_checkpoint(CHECKPOINT_PATH)\n",
    "model_dict, _, history, _ = state\n",
    "model.load_state_dict(model_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup the loss function and optimizer\n",
    "optimizer = \\\n",
    "    optim.Adagrad(\n",
    "        filter(lambda p: p.requires_grad, model.parameters()),\n",
    "        lr=config[\"optim\"][\"lr\"],\n",
    "        weight_decay=config[\"optim\"][\"weight_decay\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make predictions\n",
    "examples = [\n",
    "    (\"How do I connect to VPN?\", \"I need connecting to VPN\"),\n",
    "    (\"What's the wifi password?\", \"I want to get online.\")\n",
    "]\n",
    "\n",
    "featurized_examples, processed_texts = \\\n",
    "    texts_to_features(examples, word_vectors, embeddings_size, max_length)\n",
    "scores = model(featurized_examples)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
